// Copyright 2025 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef PRIVACY_PROOFS_ZK_LIB_CIRCUITS_CBOR_PARSER_V2_CBOR_H_
#define PRIVACY_PROOFS_ZK_LIB_CIRCUITS_CBOR_PARSER_V2_CBOR_H_

#include <stddef.h>
#include <stdint.h>

#include <array>
#include <vector>

#include "circuits/cbor_parser_v2/cbor_byte_decoder.h"
#include "circuits/cbor_parser_v2/cbor_constants.h"
#include "circuits/logic/bit_plucker.h"
#include "circuits/logic/counter.h"
#include "circuits/logic/memcmp.h"
#include "circuits/logic/routing.h"
#include "circuits/logic/unary_plucker.h"
#include "util/panic.h"

namespace proofs {
template <class Logic, size_t IndexBits = CborConstants::kIndexBits>
class Cbor {
 public:
  // Shorthand for the helper circuits and for the wire types that
  // LOGIC and the counter helper expose.
  using CounterL = Counter<Logic>;
  using CborBD = CborByteDecoder<Logic>;
  using Field = typename Logic::Field;
  using EltW = typename Logic::EltW;
  using CEltW = typename CounterL::CEltW;
  using BitW = typename Logic::BitW;
  using v8 = typename Logic::v8;
  // Width in bits of an index into the input; taken from the
  // IndexBits template parameter.
  static constexpr size_t kIndexBits = IndexBits;
  // Number of counters carried by the parser at every position.
  static constexpr size_t kNCounters = CborConstants::kNCounters;
  // One counter wire per counter slot.
  using counters = std::array<CEltW, kNCounters>;
  // One selector bit per counter slot.
  using bv_counters = typename Logic::template bitvec<kNCounters>;

  // a bitvector that contains an index into the input
  // (byte) array.
  using vindex = typename Logic::template bitvec<kIndexBits>;

  // Builds the helper sub-circuits (counter arithmetic, byte decoder,
  // header plucker, selector plucker) on top of L.  L is kept by
  // reference, so it must outlive this object.
  explicit Cbor(const Logic& l)
      : l_(l), ctr_(l), bd_(l), header_plucker_(l), sel_plucker_(l) {}

  // Witness wires supplied by the prover for one input position.
  // witness_wires() defines the order in which they are allocated.
  struct position_witness {
    // Plucker-encoded header flag; decode_all() plucks it and uses
    // bit 0 as HEADER.
    EltW encoded_header;
    // Plucker-encoded counter selector; parse() plucks it into SEL.
    EltW encoded_sel;
    // SLEN_NEXT for this position, checked by assert_decode().
    CEltW slen_next;
    // Counters after this position, checked by assert_parse().
    counters cc_next;

    // In principle we could save witnesses by storing the product of
    // INVPROD_DECODE and INVPROD_PARSE, at the expense of commingling
    // the decoder and parser assertions.  Keep them distinct until we
    // need this optimization.
    //
    // Neither field is given a wire at position 0, see witness_wires().
    EltW invprod_decode;  // inverse of a certain product, see assert_decode()
    EltW invprod_parse;   // inverse of a certain product, see assert_parse()
  };

  //------------------------------------------------------------
  // Canonical order of the witness wires.
  //------------------------------------------------------------
  void witness_wires(size_t n, position_witness pw[/*n*/]) const {
    for (size_t i = 0; i < n; ++i) {
      pw[i].encoded_header = l_.eltw_input();
      pw[i].encoded_sel = l_.eltw_input();
      pw[i].slen_next = ctr_.input();
      for (size_t l = 0; l < kNCounters; ++l) {
        pw[i].cc_next[l] = ctr_.input();
      }
      if (i > 0) {
        pw[i].invprod_decode = l_.eltw_input();
        pw[i].invprod_parse = l_.eltw_input();
      } else {
        // Witnesses at i==0 are unused and undefined.
        // Do not create wires for them.
      }
    }
  }

  //------------------------------------------------------------
  // Decoder (lexer)
  //------------------------------------------------------------
  // Per-position lexer state, filled in by decode_all().
  struct decode {
    // wires generated by the byte decoder given the input.
    typename CborBD::decode bd;

    // wires generated by the lexer from witnesses
    // (bit 0 of the plucked ENCODED_HEADER witness; TRUE when this
    // position is treated as a header)
    BitW header;
  };

  void assert_decode(size_t n, const decode ds[/*n*/],
                     const position_witness pw[/*n*/]) const {
    const Logic& L = l_;  // shorthand

    // -------------------------------------------------------------
    // Byte decoder didn't fail
    for (size_t i = 0; i < n; ++i) {
      L.assert_implies(&ds[i].header, L.lnot(ds[i].bd.invalid));
    }

    // if COUNT_IS_NEXT_V8 is TRUE in the last position,
    // then the input is invalid.
    L.assert_implies(&ds[n - 1].header, L.lnot(ds[n - 1].bd.count_is_next_v8));

    // -------------------------------------------------------------
    CEltW mone_counter = ctr_.mone();

    // Check the SLEN update equation
    //   SLEN_NEXT[I] = HEADER[I] ? LENGTH[I] : (SLEN[I] - 1)
    // where
    //   SLEN[I] = (I == 0) ? 0 : SLEN_NEXT[I - 1]
    for (size_t i = 0; i < n; ++i) {
      CEltW slen = (i == 0) ? ctr_.as_counter(0) : pw[i - 1].slen_next;
      CEltW slenm1 = ctr_.add(&slen, mone_counter);
      CEltW length = ds[i].bd.length;

      if (i + 1 < n) {
        CEltW len_i =
            ctr_.ite0(&ds[i].bd.length_plus_next_v8, ds[i + 1].bd.as_counter);
        length = ctr_.add(&length, len_i);
      } else {
        // it is an error if, in a header, the length of the token is
        // in a byte past the end of the document
        L.assert_implies(&ds[i].header, L.lnot(ds[i].bd.length_plus_next_v8));
      }

      CEltW slen_next = ctr_.mux(&ds[i].header, &length, slenm1);

      ctr_.assert_eq(&slen_next, pw[i].slen_next);
    }

    // Now check the headers.
    {
      // "The first position is a header"
      L.assert1(ds[0].header);
    }

    {
      EltW one = L.konst(L.one());

      // "\A I : (SLEN_NEXT[I] == 1)  IFF  HEADER[I+1]"
      {
        // "\A I : HEADER[I+1] => (SLEN_NEXT[I] == 1)"
        for (size_t i = 0; i < n; ++i) {
          // There is a header past the end of the document,
          // i.e., SLEN_NEXT[N-1] == 1
          BitW headerp1 = (i + 1 < n) ? ds[i + 1].header : L.bit(1);

          CEltW implies =
              ctr_.ite0(&headerp1, ctr_.add(&pw[i].slen_next, mone_counter));
          ctr_.assert0(implies);
        }
      }
      {
        // "\A I : (SLEN_NEXT[I] == 1) => HEADER[I+1] "
        // Verify via the invertibility of
        //
        //    HEADER[I+1] ? 1 : (SLEN_NEXT[I] - 1)
        //
        // HEADER[N] is implicitly TRUE, so the conditional chooses
        // 1 and we don't need to check I==N-1.
        for (size_t i = 0; i + 1 < n; ++i) {
          CEltW snm1 = ctr_.add(&pw[i].slen_next, mone_counter);
          EltW x = L.mux(&ds[i + 1].header, &one, ctr_.znz_indicator(snm1));
          auto want_one = L.mul(&x, pw[i + 1].invprod_decode);
          L.assert_eq(&want_one, one);
        }
      }
    }
  }

  //------------------------------------------------------------
  // Parser
  //------------------------------------------------------------
  // Per-position parser output, filled in by parse() from the witnesses.
  struct parse_output {
    // Selector bits plucked from the ENCODED_SEL witness; assert_parse()
    // checks that they are mutually exclusive.
    bv_counters sel;
    // Counters after this position, copied from the witness.
    counters cc_next;
  };

  // "parse" here means produce the parser output, which is trivial
  // because the parser output is given as witnesses.
  void parse(size_t n, parse_output ps[/*n*/], const decode ds[/*n*/],
             const position_witness pw[/*n*/]) const {
    // unpluck the selector and pass the counters through
    for (size_t i = 0; i < n; ++i) {
      ps[i].sel = sel_plucker_.pluck(pw[i].encoded_sel);
      ps[i].cc_next = pw[i].cc_next;
    }
  }

  // Given the current counters CC and the local count COUNT_I, return
  // the new counters COUNTERS_NEXT.  Set *OVERFLOW to TRUE if
  // attempting to update a counter that does not exist.
  //
  // See cbor_witness.h::counters_next() for a perhaps more
  // readable description of this logic.
  counters counters_next(const counters& cc, const bv_counters& sel,
                         const CEltW& count_i, const decode& ds,
                         BitW* overflow) const {
    counters updated = cc;

    // Pass 1: at a header, the selected counter is decremented.
    //   if (header && sel[lvl]) updated[lvl] = cc[lvl] - 1;
    for (size_t lvl = 0; lvl < kNCounters; ++lvl) {
      BitW selected = l_.land(&sel[lvl], ds.header);
      CEltW minus_one = ctr_.mone();
      CEltW delta = ctr_.ite0(&selected, minus_one);
      updated[lvl] = ctr_.add(&cc[lvl], delta);
    }

    // Pass 2: a tag, array or map at the selected slot installs a fresh
    // count in the slot right above it.
    for (size_t lvl = 0; lvl < kNCounters; ++lvl) {
      // FRESH = 1       if TAGP
      //         COUNT   if ARRAYP
      //         2*COUNT if MAPP
      CEltW doubled = ctr_.add(&count_i, count_i);
      CEltW unit = ctr_.as_counter(1);

      CEltW items_count = ctr_.mux(&ds.bd.arrayp, /*array:*/ &count_i,
                                   /*map:*/ doubled);
      CEltW tag_count = ctr_.ite0(&ds.bd.tagp, unit);
      CEltW fresh = ctr_.mux(&ds.bd.itemsp, &items_count, tag_count);

      BitW selected = l_.land(&sel[lvl], ds.header);
      BitW opens_slot = l_.lor(&ds.bd.tagp, ds.bd.itemsp);
      BitW install = l_.land(&selected, opens_slot);

      const size_t child = lvl + 1;
      if (child == kNCounters) {
        // There is no slot above the last one.  This branch runs in
        // the final iteration, so *overflow is always written and
        // needs no initialization.
        *overflow = install;
      } else {
        updated[child] = ctr_.mux(&install, &fresh, updated[child]);
      }
    }

    return updated;
  }

  // check that all counters are correctly updated
  void assert_counter_updates(size_t n, const decode ds[/*n*/],
                              const parse_output ps[/*n*/]) const {
    const Logic& L = l_;  // shorthand

    for (size_t i = 0; i < n; ++i) {
      // Finish the decoding of COUNT, which may need
      // lookahead
      CEltW count_i = ds[i].bd.count_as_counter;
      if (i + 1 < n) {
        // if COUNT_IS_NEXT_V8, read COUNT from the next V8
        count_i = ctr_.mux(&ds[i].bd.count_is_next_v8, &ds[i + 1].bd.as_counter,
                           count_i);
      } else {
        // if COUNT_I is actually needed, COUNT_IS_NEXT_V8 must be 0
        L.assert_implies(&ds[i].header, L.lnot(ds[i].bd.count_is_next_v8));
      }

      if (i > 0) {
        BitW overflow;

        // By convention, COUNTERS[I] = COUNTERS_NEXT[I - 1]
        const counters cc = ps[i - 1].cc_next;

        const counters cc_next =
            counters_next(cc, ps[i].sel, count_i, ds[i], &overflow);
        L.assert0(overflow);

        for (size_t l = 0; l < kNCounters; ++l) {
          ctr_.assert_eq(&ps[i].cc_next[l], cc_next[l]);
        }
      }
    }
  }

  void assert_parse(size_t n, const decode ds[/*n*/],
                    const parse_output ps[/*n*/],
                    const position_witness pw[/*n*/]) const {
    const Logic& L = l_;  // shorthand

    assert_counter_updates(n, ds, ps);

    for (size_t i = 0; i < n; ++i) {
      // "The SEL witnesses are mutually exclusive."
      // The bit plucker guarantees that the SEL witnesses
      // are bits, but in principle one could feed an
      // out-of-domain input to the bit plucker that
      // sets more than one bit.
      // Another way to accomplish the same effect would
      // be to range-check the input to the bit plucker.
      for (size_t l = 0; l < kNCounters; ++l) {
        for (size_t m = l + 1; m < kNCounters; ++m) {
          L.assert0(L.land(&ps[i].sel[l], ps[i].sel[m]));
        }
      }

      // "at a header, at least one SEL bit is set"
      auto sum = L.bit(0);
      for (size_t l = 0; l < kNCounters; ++l) {
        // known to be exclusive by the test above
        sum = L.lor_exclusive(&sum, ps[i].sel[l]);
      }
      L.assert_implies(&ds[i].header, sum);
    }

    // "All counters are zero at the end of the input"
    // CC_NEXT[I][L] is the state of the parser at the end
    // of position I, so CC_NEXT[N-1][L] is the final state.
    for (size_t l = 0; l < kNCounters; ++l) {
      ctr_.assert0(ps[n - 1].cc_next[l]);
    }

    // SEL[0][0] is set.  We implicitly define CC_NEXT[-1][L] to make
    // this the correct choice.  Because SEL[I][.] are asserted above
    // to be mutually exclusive, there is no need to test the other
    // selectors.
    L.assert1(ps[0].sel[0]);

    for (size_t i = 0; i + 1 < n; ++i) {
      // "If SEL[I+1][L] is set, then CC_NEXT[I][L] is the nonzero
      // counter of maximal L.  (CC_NEXT[I][L] contains the output
      // counters of stage I, which affect SEL[I+1].)  Here we check
      // maximality:  CC_NEXT[I][J]=0 for J>L.  See below for
      // SEL[I+1][L] => (CC_NEXT[I][L] != 0).
      BitW b = ps[i + 1].sel[0];
      for (size_t l = 1; l < kNCounters; ++l) {
        // b => CC_NEXT[i][l] == 0
        ctr_.assert0(ctr_.ite0(&b, ps[i].cc_next[l]));
        b = L.lor(&b, ps[i + 1].sel[l]);
      }
    }

    // "SEL[I+1][L] => (CC_NEXT[I][L] != 0)"
    // Check via the invertibility of
    //
    //    PROD_{L} SEL[I+1][L] ? CC_NEXT[I][L] : 1
    // We don't need to check I == N-1 because SEL[N][.] is FALSE
    // by definition, and thus the product is the constant 1
    {
      auto one = L.konst(1);
      for (size_t i = 0; i + 1 < n; ++i) {
        EltW p = L.mul(0, kNCounters, [&](size_t l) {
          EltW cc_next = ctr_.znz_indicator(ps[i].cc_next[l]);
          return L.mux(&ps[i + 1].sel[l], &cc_next, one);
        });
        auto want_one = L.mul(&p, pw[i + 1].invprod_parse);
        L.assert_eq(&want_one, one);
      }
    }
  }

  //------------------------------------------------------------
  // "J is the header of a string of length LEN containing BYTES"
  //------------------------------------------------------------
  void assert_text_at(size_t n, const vindex& j, size_t len,
                      const uint8_t bytes[/*len*/],
                      const decode ds[/*n*/]) const {
    const Logic& L = l_;  // shorthand
    const Routing<Logic> R(L);

    // we don't handle long strings
    proofs::check(len < 24, "len < 24");

    assert_header(n, j, ds);

    std::vector<EltW> A(n);
    for (size_t i = 0; i < n; ++i) {
      A[i] = ds[i].bd.as_scalar;
    }

    // shift len+1 bytes, including the header.
    std::vector<EltW> B(len + 1);
    const EltW defaultA = L.konst(256);  // a constant that cannot appear in A[]
    R.shift(j, len + 1, B.data(), n, A.data(), defaultA, /*unroll=*/3);

    size_t expected_header = (3 << 5) + len;
    L.assert_eq(&B[0], L.konst(expected_header));
    for (size_t i = 0; i < len; ++i) {
      auto bi = L.konst(bytes[i]);
      L.assert_eq(&B[i + 1], bi);
    }
  }

  //------------------------------------------------------------
  // "J is a header containing unsigned U."
  //------------------------------------------------------------
  void assert_unsigned_at(size_t n, const vindex& j, uint64_t u,
                          const decode ds[/*n*/]) const {
    // only small u for now
    proofs::check(u < 24, "u < 24");

    size_t expected = (0 << 5) + u;
    assert_atom_at(n, j, l_.konst(expected), ds);
  }

  //------------------------------------------------------------
  // "J is a header containing the negative integer -1-U."  (U >= 0;
  // CBOR major type 1 encodes -1 minus the argument, so U=0 means -1)
  //------------------------------------------------------------
  void assert_negative_at(size_t n, const vindex& j, uint64_t u,
                          const decode ds[/*n*/]) const {
    // only small u for now
    proofs::check(u < 24, "u < 24");

    size_t expected = (1 << 5) + u;
    assert_atom_at(n, j, l_.konst(expected), ds);
  }

  //------------------------------------------------------------
  // "J is a header containing a boolean primitive (0xF4 or 0xF5)."
  //
  //------------------------------------------------------------
  void assert_bool_at(size_t n, const vindex& j, bool val,
                      const decode ds[/*n*/]) const {
    size_t expected = (7 << 5) + (val ? 21 : 20);
    assert_atom_at(n, j, l_.konst(expected), ds);
  }

  // Helps assemble the checks for date assertions.
  void date_helper(size_t n, const vindex& j, const decode ds[/*n*/],
                   std::vector<v8>& B /* size 22 */) const {
    const Logic& L = l_;  // shorthand
    const Routing<Logic> R(L);
    assert_header(n, j, ds);

    std::vector<v8> A(n);
    for (size_t i = 0; i < n; ++i) {
      A[i] = ds[i].bd.as_bits;
    }

    const v8 defaultA =
        L.template vbit<8>(0);  // a constant that cannot appear in A[]
    R.shift(j, 20 + 2, B.data(), n, A.data(), defaultA, /*unroll=*/3);

    // Check for tag: date/time string.
    L.vassert_eq(&B[0], L.template vbit<8>(0xc0));

    // Check for string(20)
    L.vassert_eq(&B[1], L.template vbit<8>(0x74));
  }

  //------------------------------------------------------------
  // "J is a header containing date d < now."  now is 20 bytes
  // in the format 2023-11-01T09:00:00Z
  //------------------------------------------------------------
  void assert_date_before_at(size_t n, const vindex& j, const v8 now[/* 20 */],
                             const decode ds[/*n*/]) const {
    const Logic& L = l_;  // shorthand
    const Memcmp<Logic> CMP(L);
    std::vector<v8> B(20 + 2);
    date_helper(n, j, ds, B);
    auto lt = CMP.lt(20, &B[2], now);
    L.assert1(lt);
  }

  //------------------------------------------------------------
  // "J is a header containing date d > now."  now is 20 bytes in the
  // format 2023-11-01T09:00:00Z
  // ------------------------------------------------------------
  void assert_date_after_at(size_t n, const vindex& j, const v8 now[/* 20 */],
                            const decode ds[/*n*/]) const {
    const Logic& L = l_;  // shorthand
    const Memcmp<Logic> CMP(L);
    std::vector<v8> B(20 + 2);
    date_helper(n, j, ds, B);
    auto lt = CMP.lt(20, now, &B[2]);
    L.assert1(lt);
  }

  //------------------------------------------------------------
  // "J is a header represented by the byte EXPECTED in the
  // input."
  //------------------------------------------------------------
  void assert_atom_at(size_t n, const vindex& j, const EltW& expected,
                      const decode ds[/*n*/]) const {
    const Logic& L = l_;  // shorthand
    const Routing<Logic> R(L);

    assert_header(n, j, ds);

    std::vector<EltW> A(n);
    for (size_t i = 0; i < n; ++i) {
      A[i] = ds[i].bd.as_scalar;
    }

    EltW B[1];
    size_t unroll = 3;
    R.shift(j, 1, B, n, A.data(), L.konst(256), unroll);
    L.assert_eq(&B[0], expected);
  }

  //------------------------------------------------------------
  // "Position j contains a header"
  //------------------------------------------------------------
  void assert_header(size_t n, const vindex& j, const decode ds[/*n*/]) const {
    const Logic& L = l_;  // shorthand

    L.vassert_is_bit(j);

    // giant dot product since the veq(j, .) terms are mutually exclusive.
    auto f = [&](size_t i) { return L.land(&ds[i].header, L.veq(j, i)); };
    L.assert1(L.lor_exclusive(0, n, f));
  }

  //------------------------------------------------------------
  // "A map starts at position j"
  //------------------------------------------------------------
  void assert_map_header(size_t n, const vindex& j,
                         const decode ds[/*n*/]) const {
    const Logic& L = l_;  // shorthand

    L.vassert_is_bit(j);

    // giant dot product since the veq(j, .) terms are mutually exclusive.
    auto f = [&](size_t i) {
      auto eq_ji = L.veq(j, i);
      auto dsi = L.land(&ds[i].bd.mapp, ds[i].header);
      return L.land(&eq_ji, dsi);
    };
    L.assert1(L.lor_exclusive(0, n, f));
  }

  //------------------------------------------------------------
  // "Position M starts a map of level LEVEL.  (K, V) are headers
  // representing the J-th pair in that map"
  //------------------------------------------------------------
  void assert_map_entry(size_t n, const vindex& m, size_t level,
                        const vindex& k, const vindex& v, const vindex& j,
                        const decode ds[/*n*/],
                        const parse_output ps[/*n*/]) const {
    const Logic& L = l_;  // shorthand
    const Routing<Logic> R(L);

    assert_map_header(n, m, ds);
    assert_header(n, k, ds);
    assert_header(n, v, ds);

    for (size_t l = 0; l < kNCounters; ++l) {
      // Hack: temporarily treat CEltW as EltW so as to reuse
      // the shifter.
      std::vector<EltW> A(n);
      for (size_t i = 0; i < n; ++i) {
        A[i] = ps[i].cc_next[l].e;
      }

      // Select counters[m], counters[k], and counters[v].
      CEltW cm, ck, cv;

      const size_t unroll = 3;
      R.shift(m, 1, &cm.e, n, A.data(), L.konst(0), unroll);
      R.shift(k, 1, &ck.e, n, A.data(), L.konst(0), unroll);
      R.shift(v, 1, &cv.e, n, A.data(), L.konst(0), unroll);

      if (l <= level) {
        // Counters[L] must agree at the key, value, and root
        // of the map.
        ctr_.assert_eq(&cm, ck);
        ctr_.assert_eq(&cm, cv);
      } else if (l == level + 1) {
        CEltW one = ctr_.as_counter(1);
        CEltW two = ctr_.as_counter(2);
        // LEVEL+1 counters must have the right number of decrements.
        // Specifically, if the counter at the map is N, then the j-th
        // key has N-(2*j+1) and the j-th value has N-(2*j+2)
        CEltW jctr = ctr_.as_counter(j);
        CEltW twoj = ctr_.add(&jctr, jctr);
        ctr_.assert_eq(&cm, ctr_.add(&ck, ctr_.add(&twoj, one)));
        ctr_.assert_eq(&cm, ctr_.add(&cv, ctr_.add(&twoj, two)));
      } else {
        // not sure if this is necessary, but all other counters
        // of CM are supposed to be zero.
        ctr_.assert0(cm);
      }
    }
  }

  //------------------------------------------------------------
  // "JROOT is the first byte of the actual (unpadded) input and
  // all previous bytes are 0"
  //------------------------------------------------------------
  void assert_input_starts_at(size_t n, const vindex& jroot,
                              const vindex& input_len,
                              const decode ds[/*n*/]) const {
    const Logic& L = l_;  // shorthand

    L.assert1(L.vleq(input_len, n));
    L.assert1(L.vlt(jroot, n));
    auto tot = L.vadd(jroot, input_len);
    L.vassert_eq(tot, n);

    for (size_t i = 0; i < n; ++i) {
      L.assert0(L.lmul(&ds[i].bd.as_scalar, L.vlt(i, jroot)));
    }
  }

  //------------------------------------------------------------
  // Utilities
  //------------------------------------------------------------
  // The circuit accepts up to N input positions, of which
  // INPUT_LEN are actual input and the rest are ignored.
  void decode_all(size_t n, decode ds[/*n*/], const v8 in[/*n*/],
                  const position_witness pw[/*n*/]) const {
    for (size_t i = 0; i < n; ++i) {
      ds[i].bd = bd_.decode_one_v8(in[i]);
      auto eh = header_plucker_.pluck(pw[i].encoded_header);
      ds[i].header = eh[0];
    }
  }

  void decode_and_assert_decode(size_t n, decode ds[/*n*/], const v8 in[/*n*/],
                                const position_witness pw[/*n*/]) const {
    // Fill DS from the input and the witnesses ...
    decode_all(n, ds, in, pw);

    // ... and then constrain it with the lexer assertions.
    assert_decode(n, ds, pw);
  }

  void decode_and_assert_decode_and_parse(
      size_t n, decode ds[/*n*/], parse_output ps[/*n*/], const v8 in[/*n*/],
      const position_witness pw[/*n*/]) const {
    // Lexer: fill DS and emit the decoder assertions.
    decode_and_assert_decode(n, ds, in, pw);

    // Parser: fill PS from the witnesses, then emit the parser
    // assertions.
    parse(n, ps, ds, pw);
    assert_parse(n, ds, ps, pw);
  }

 private:
  // The logic backend that all constraints are emitted into; not
  // owned, must outlive this object.
  const Logic& l_;
  // Counter arithmetic helper.
  const CounterL ctr_;
  // Decoder for a single input byte, see decode_all().
  const CborBD bd_;
  // Expands the ENCODED_HEADER witness; bit 0 is the header flag.
  const BitPlucker<Logic, 1> header_plucker_;
  // Expands the ENCODED_SEL witness into kNCounters selector bits.
  const UnaryPlucker<Logic, kNCounters> sel_plucker_;
};
}  // namespace proofs

#endif  // PRIVACY_PROOFS_ZK_LIB_CIRCUITS_CBOR_PARSER_V2_CBOR_H_
